import autograd.numpy as np
import numpy.testing as np_testing
import pytest
from scipy.linalg import eigvalsh, expm, logm

from pymanopt.manifolds import (
    HermitianPositiveDefinite,
    SpecialHermitianPositiveDefinite,
    SymmetricPositiveDefinite,
)
from pymanopt.tools.multi import (
    multiexpm,
    multihconj,
    multiherm,
    multilogm,
    multisym,
    multitransp,
)


def geodesic(point_a, point_b, alpha):
    """Evaluate the geodesic from ``point_a`` to ``point_b`` at ``alpha``.

    With the Cholesky factorization ``point_a = c @ c^H`` this returns
    ``c @ expm(alpha * logm(c^{-1} @ point_b @ c^{-H})) @ c^H``.
    ``alpha = 0`` yields ``point_a`` and ``alpha = 1`` yields ``point_b``.
    The tests call it both on single matrices and on stacks of matrices.

    Raises:
        ValueError: if ``alpha`` lies outside [0, 1].
    """
    if alpha < 0 or alpha > 1:
        raise ValueError("Exponent must be in [0,1]")
    # Whiten point_b with the Cholesky factor of point_a, raise the result to
    # the power alpha via expm(alpha * logm(.)), then undo the whitening.
    chol = np.linalg.cholesky(point_a)
    chol_inv = np.linalg.inv(chol)
    whitened = chol_inv @ point_b @ multihconj(chol_inv)
    log_whitened = multilogm(whitened, positive_definite=True)
    power = multiexpm(alpha * log_whitened, symmetric=False)
    return chol @ power @ multihconj(chol)


class TestSingleSymmetricPositiveDefiniteManifold:
    """Tests for ``SymmetricPositiveDefinite`` with a single matrix (k=1).

    Subclasses reuse these checks for related manifolds by overriding
    ``setup`` (and any test whose expectations differ).

    [Bha2007]: R. Bhatia, "Positive Definite Matrices", Princeton
    University Press, 2007.
    """

    @pytest.fixture(autouse=True)
    def setup(self):
        self.n = 15
        self.manifold = SymmetricPositiveDefinite(self.n)

    def test_random_point(self):
        # Test that random_point returns a point on the manifold and that two
        # different points generated by random_point aren't too close
        # together.
        n = self.n
        manifold = self.manifold
        x = manifold.random_point()
        y = manifold.random_point()

        assert np.shape(x) == (n, n)

        # Check symmetry
        np_testing.assert_allclose(x, multisym(x))

        # Check positivity of eigenvalues
        w = np.linalg.eigvalsh(x)
        assert (w > 0).all()

        # Check randomness
        assert np.linalg.norm(x - y) > 1e-3

    def test_dist(self):
        manifold = self.manifold
        x = manifold.random_point()
        y = manifold.random_point()
        z = manifold.random_point()
        dist_xy = manifold.dist(x, y)

        # Test separability
        np_testing.assert_almost_equal(manifold.dist(x, x), 0.0)

        # Test symmetry
        np_testing.assert_almost_equal(dist_xy, manifold.dist(y, x))

        # Test triangle inequality
        assert dist_xy <= manifold.dist(x, z) + manifold.dist(z, y)

        # Test alternative implementation (see equation (6.14) in [Bha2007]):
        # the 2-norm of the logs of the generalized eigenvalues of (x, y).
        d = np.sqrt((np.log(eigvalsh(x, y)) ** 2).sum())
        np_testing.assert_almost_equal(dist_xy, d)

        # Test exponential metric increasing property
        # (see equation (6.8) in [Bha2007]).
        assert dist_xy >= np.linalg.norm(logm(x) - logm(y))

        # Check that dist is consistent with log
        np_testing.assert_almost_equal(
            dist_xy, manifold.norm(x, manifold.log(x, y))
        )

        # Test invariance under inversion
        np_testing.assert_almost_equal(
            dist_xy, manifold.dist(np.linalg.inv(y), np.linalg.inv(x))
        )

        # Test congruence-invariance. The congruence matrix must be
        # invertible; a Gaussian random matrix is with probability one.
        a = np.random.normal(size=(self.n, self.n))
        axa = a @ x @ multitransp(a)
        aya = a @ y @ multitransp(a)
        np_testing.assert_almost_equal(dist_xy, manifold.dist(axa, aya))

        # Test proportionality (see equation (6.12) in [Bha2007]).
        alpha = np.random.uniform()
        np_testing.assert_almost_equal(
            manifold.dist(x, geodesic(x, y, alpha)), alpha * dist_xy
        )

    def test_exp(self):
        # Compare against the closed form x @ expm(x^{-1} u) and check that
        # exp(x, u) ~ x + u for small u.
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        e = expm(np.linalg.solve(x, u))

        np_testing.assert_allclose(x @ e, manifold.exp(x, u))
        u = u * 1e-6
        np_testing.assert_allclose(manifold.exp(x, u), x + u)

    def test_random_tangent_vector(self):
        # Test that random_tangent_vector returns a symmetric matrix with
        # unit norm and that two random tangent vectors are different.
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        v = manifold.random_tangent_vector(x)
        np_testing.assert_allclose(multisym(u), u)
        np_testing.assert_almost_equal(1, manifold.norm(x, u))
        assert np.linalg.norm(u - v) > 1e-3

    def test_norm(self):
        # At the identity the manifold norm is expected to equal the
        # Frobenius norm.
        manifold = self.manifold
        x = manifold.random_point()
        np_testing.assert_almost_equal(
            manifold.norm(np.eye(self.n), x), np.linalg.norm(x)
        )

    def test_exp_log_inverse(self):
        manifold = self.manifold
        x = manifold.random_point()
        y = manifold.random_point()
        u = manifold.log(x, y)
        np_testing.assert_allclose(manifold.exp(x, u), y)

    def test_log_exp_inverse(self):
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        y = manifold.exp(x, u)
        np_testing.assert_allclose(manifold.log(x, y), u)


class TestSingleHermitianPositiveDefiniteManifold(
    TestSingleSymmetricPositiveDefiniteManifold
):
    """Run the single-matrix SPD tests against ``HermitianPositiveDefinite``.

    Overrides the tests whose expectations differ for complex Hermitian
    matrices (dimension, dtype, Hermitian rather than symmetric structure).
    """

    @pytest.fixture(autouse=True)
    def setup(self):
        self.n = 15
        self.manifold = HermitianPositiveDefinite(self.n)

    def test_dim(self):
        # NOTE(review): the real dimension of the space of n x n Hermitian
        # matrices is n**2 (n real diagonal entries plus n * (n - 1) real
        # parameters off the diagonal), not n * (n + 1). This asserts the
        # value the implementation currently reports; confirm which is meant.
        manifold = self.manifold
        n = self.n
        np_testing.assert_equal(manifold.dim, n * (n + 1))

    def test_random_point(self):
        # Test that random_point returns a complex Hermitian positive
        # definite matrix and that two different points generated by
        # random_point aren't too close together.
        n = self.n
        manifold = self.manifold
        x = manifold.random_point()
        y = manifold.random_point()

        assert np.shape(x) == (n, n)
        assert x.dtype == complex

        # Check Hermitian symmetry
        np_testing.assert_allclose(x, multiherm(x))

        # Check positivity of eigenvalues
        w = np.linalg.eigvalsh(x)
        assert (w > 0).all()

        # Check randomness
        assert np.linalg.norm(x - y) > 1e-3

    def test_random_tangent_vector(self):
        # Test that random_tangent_vector returns a complex Hermitian matrix
        # with unit norm and that two random tangent vectors are different.
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        v = manifold.random_tangent_vector(x)
        np_testing.assert_allclose(multiherm(u), u)
        np_testing.assert_almost_equal(1, manifold.norm(x, u))
        assert u.shape == (self.n, self.n)
        assert u.dtype == complex
        assert np.linalg.norm(u - v) > 1e-3


class TestMultiSymmetricPositiveDefiniteManifold:
    """Tests for a product of ``k`` ``SymmetricPositiveDefinite`` manifolds.

    Points and tangent vectors are stacks of shape (k, n, n). Subclasses
    reuse these checks for related manifolds by overriding ``setup`` (and any
    test whose expectations differ).

    [Bha2007]: R. Bhatia, "Positive Definite Matrices", Princeton
    University Press, 2007.
    """

    @pytest.fixture(autouse=True)
    def setup(self):
        self.n = n = 10
        self.k = k = 3
        self.manifold = SymmetricPositiveDefinite(n, k=k)

    def test_dim(self):
        # k copies of the n * (n + 1) / 2 dimensional space of symmetric
        # matrices.
        manifold = self.manifold
        n = self.n
        k = self.k
        np_testing.assert_equal(manifold.dim, 0.5 * k * n * (n + 1))

    def test_typical_dist(self):
        manifold = self.manifold
        np_testing.assert_equal(manifold.typical_dist, np.sqrt(manifold.dim))

    def test_dist(self):
        manifold = self.manifold
        x = manifold.random_point()
        y = manifold.random_point()
        z = manifold.random_point()
        dist_xy = manifold.dist(x, y)

        # Test separability
        np_testing.assert_almost_equal(manifold.dist(x, x), 0.0)

        # Test symmetry
        np_testing.assert_almost_equal(dist_xy, manifold.dist(y, x))

        # Test triangle inequality
        assert dist_xy <= manifold.dist(x, z) + manifold.dist(z, y)

        # Test exponential metric increasing property
        # (see equation (6.8) in [Bha2007]).
        logx, logy = multilogm(x), multilogm(y)
        assert dist_xy >= np.linalg.norm(logx - logy)

        # Check that dist is consistent with log
        np_testing.assert_almost_equal(
            dist_xy, manifold.norm(x, manifold.log(x, y))
        )

        # Test invariance under inversion
        np_testing.assert_almost_equal(
            dist_xy, manifold.dist(np.linalg.inv(y), np.linalg.inv(x))
        )

        # Test congruence-invariance. The congruence matrix must be
        # invertible; a Gaussian random matrix is with probability one. It is
        # broadcast over the k matrices of each stack.
        a = np.random.normal(size=(self.n, self.n))
        axa = a @ x @ multitransp(a)
        aya = a @ y @ multitransp(a)
        np_testing.assert_almost_equal(dist_xy, manifold.dist(axa, aya))

        # Test proportionality (see equation (6.12) in [Bha2007]).
        alpha = np.random.uniform()
        np_testing.assert_almost_equal(
            manifold.dist(x, geodesic(x, y, alpha)), alpha * dist_xy
        )

    def test_inner_product(self):
        # For tangent vectors x @ a and x @ b the expected inner product is
        # sum_i trace(a_i @ b_i), computed here as a full tensor contraction.
        manifold = self.manifold
        k = self.k
        n = self.n
        x = manifold.random_point()
        a, b = np.random.normal(size=(2, k, n, n))
        np_testing.assert_almost_equal(
            np.tensordot(a, b.transpose((0, 2, 1)), axes=a.ndim),
            manifold.inner_product(x, x @ a, x @ b),
        )

    def test_projection(self):
        manifold = self.manifold
        x = manifold.random_point()
        a = np.random.normal(size=(self.k, self.n, self.n))
        np_testing.assert_allclose(manifold.projection(x, a), multiherm(a))

    def test_euclidean_to_riemannian_gradient(self):
        manifold = self.manifold
        x = manifold.random_point()
        u = np.random.normal(size=(self.k, self.n, self.n))
        np_testing.assert_allclose(
            manifold.euclidean_to_riemannian_gradient(x, u),
            x @ multiherm(u) @ x,
        )

    def test_euclidean_to_riemannian_hessian(self):
        # Reference value computed with Manopt's (slower) formula.
        manifold = self.manifold
        n = self.n
        k = self.k
        x = manifold.random_point()
        egrad, ehess = np.random.normal(size=(2, k, n, n))
        u = manifold.random_tangent_vector(x)

        hess = x @ multiherm(ehess) @ x + 2 * multiherm(
            u @ multiherm(egrad) @ x
        )

        # Correction factor for the non-constant metric
        hess = hess - multiherm(u @ multiherm(egrad) @ x)
        np_testing.assert_almost_equal(
            hess, manifold.euclidean_to_riemannian_hessian(x, egrad, ehess, u)
        )

    def test_norm(self):
        # At the identity the manifold norm is expected to equal the
        # Frobenius norm of the whole stack.
        manifold = self.manifold
        x = manifold.random_point()
        identity = np.array(self.k * [np.eye(self.n)])
        np_testing.assert_almost_equal(
            manifold.norm(identity, x), np.linalg.norm(x)
        )

    def test_random_point(self):
        # Test that random_point returns a point on the manifold and that two
        # different points generated by random_point aren't too close
        # together.
        k = self.k
        n = self.n
        manifold = self.manifold
        x = manifold.random_point()
        y = manifold.random_point()

        assert np.shape(x) == (k, n, n)

        # Check symmetry
        np_testing.assert_allclose(x, multisym(x))

        # Check positivity of eigenvalues
        w = np.linalg.eigvalsh(x)
        assert (w > 0).all()

        # Check randomness
        assert np.linalg.norm(x - y) > 1e-3

    def test_random_tangent_vector(self):
        # Test that random_tangent_vector returns a stack of symmetric
        # matrices with unit norm and that two random tangent vectors are
        # different.
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        v = manifold.random_tangent_vector(x)
        np_testing.assert_allclose(multisym(u), u)
        np_testing.assert_almost_equal(1, manifold.norm(x, u))
        assert np.linalg.norm(u - v) > 1e-3

    def test_transport(self):
        # The transported vector must lie in the tangent space at y, i.e. be
        # left unchanged by the projection onto that space.
        manifold = self.manifold
        x = manifold.random_point()
        y = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        u_transp = manifold.transport(x, y, u)
        u_transp_proj = manifold.projection(y, u_transp)
        np_testing.assert_allclose(u_transp, u_transp_proj)

    def test_exp(self):
        # Compare against the closed form x_i @ expm(x_i^{-1} u_i) for each
        # matrix in the stack (Manopt's implementation), and check that
        # exp(x, u) ~ x + u for small u.
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        e = np.array(
            [expm(np.linalg.solve(x_i, u_i)) for x_i, u_i in zip(x, u)]
        )
        np_testing.assert_allclose(x @ e, manifold.exp(x, u))
        u = u * 1e-6
        np_testing.assert_allclose(manifold.exp(x, u), x + u)

    def test_retraction(self):
        # Check that the result is on the manifold and that for small vectors
        # retraction(x, u) ~ x + u.
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        y = manifold.retraction(x, u)

        assert np.shape(y) == (self.k, self.n, self.n)
        # Check symmetry
        np_testing.assert_allclose(y, multiherm(y))

        # Check positivity of eigenvalues
        w = np.linalg.eigvalsh(y)
        assert (w > 0).all()

        u = u * 1e-6
        np_testing.assert_allclose(manifold.retraction(x, u), x + u)

    def test_exp_log_inverse(self):
        manifold = self.manifold
        x = manifold.random_point()
        y = manifold.random_point()
        u = manifold.log(x, y)
        np_testing.assert_allclose(manifold.exp(x, u), y)

    def test_log_exp_inverse(self):
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        y = manifold.exp(x, u)
        np_testing.assert_allclose(manifold.log(x, y), u)


class TestMultiHermitianPositiveDefiniteManifold(
    TestMultiSymmetricPositiveDefiniteManifold
):
    """Run the multi-matrix SPD tests against ``HermitianPositiveDefinite``.

    Overrides the tests whose expectations differ for stacks of complex
    Hermitian matrices (dimension, dtype, Hermitian structure).
    """

    @pytest.fixture(autouse=True)
    def setup(self):
        self.n = n = 10
        self.k = k = 3
        self.manifold = HermitianPositiveDefinite(n, k=k)

    def test_dim(self):
        # NOTE(review): the real dimension of the space of n x n Hermitian
        # matrices is n**2 (n real diagonal entries plus n * (n - 1) real
        # parameters off the diagonal), not n * (n + 1). This asserts the
        # value the implementation currently reports; confirm which is meant.
        manifold = self.manifold
        n = self.n
        k = self.k
        np_testing.assert_equal(manifold.dim, k * n * (n + 1))

    def test_random_point(self):
        # Test that random_point returns a stack of complex Hermitian
        # positive definite matrices and that two different points generated
        # by random_point aren't too close together.
        k = self.k
        n = self.n
        manifold = self.manifold
        x = manifold.random_point()
        y = manifold.random_point()

        assert np.shape(x) == (k, n, n)
        assert x.dtype == complex

        # Check Hermitian symmetry
        np_testing.assert_allclose(x, multiherm(x))

        # Check positivity of eigenvalues
        w = np.linalg.eigvalsh(x)
        assert (w > 0).all()

        # Check randomness
        assert np.linalg.norm(x - y) > 1e-3

    def test_random_tangent_vector(self):
        # Test that random_tangent_vector returns a stack of complex Hermitian
        # matrices with unit norm and that two random tangent vectors are
        # different.
        k = self.k
        n = self.n
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        v = manifold.random_tangent_vector(x)
        np_testing.assert_allclose(multiherm(u), u)
        np_testing.assert_almost_equal(1, manifold.norm(x, u))
        assert u.shape == (k, n, n)
        assert u.dtype == complex
        assert np.linalg.norm(u - v) > 1e-3


class TestSingleSpecialHermitianPositiveDefiniteManifold(
    TestSingleHermitianPositiveDefiniteManifold
):
    """Tests for ``SpecialHermitianPositiveDefinite`` with a single matrix.

    Points are Hermitian positive definite with unit determinant; tangent
    vectors u at x satisfy real(trace(x^{-1} u)) = 0.
    """

    @pytest.fixture(autouse=True)
    def setup(self):
        self.n = n = 10
        self.k = k = 1
        self.manifold = SpecialHermitianPositiveDefinite(n, k=k)

    def test_dim(self):
        # The unit-determinant constraint removes one dimension from the
        # Hermitian positive definite manifold.
        # NOTE(review): the real dimension of n x n Hermitian matrices is
        # n**2, so one would expect n**2 - 1 here rather than
        # n * (n + 1) - 1. This asserts the value the implementation
        # currently reports; confirm which is meant.
        manifold = self.manifold
        n = self.n
        np_testing.assert_equal(manifold.dim, n * (n + 1) - 1)

    def test_random_point(self):
        # Test that random_point returns a point on the manifold and that two
        # different points generated by random_point aren't too close
        # together.
        n = self.n
        manifold = self.manifold
        x = manifold.random_point()
        y = manifold.random_point()

        assert np.shape(x) == (n, n)
        assert x.dtype == complex

        # Check Hermitian symmetry
        np_testing.assert_allclose(x, multiherm(x))

        # Check positivity of eigenvalues
        w = np.linalg.eigvalsh(x)
        assert (w > 0).all()

        # Check unit determinant
        d = np.real(np.linalg.det(x))
        np_testing.assert_allclose(d, 1)

        # Check randomness
        assert np.linalg.norm(x - y) > 1e-3

    def test_random_tangent_vector(self):
        # Test that random_tangent_vector returns an element of the tangent
        # space with unit norm and that two random tangent vectors differ.
        manifold = self.manifold
        n = self.n
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        v = manifold.random_tangent_vector(x)

        # Check the tangent vector itself, not the base point.
        assert np.shape(u) == (n, n)
        assert u.dtype == complex

        np_testing.assert_allclose(multiherm(u), u)

        # Tangent vectors satisfy real(trace(x^{-1} u)) = 0.
        t = np.real(np.trace(np.linalg.solve(x, u)))
        np_testing.assert_almost_equal(t, 0)

        np_testing.assert_almost_equal(1, manifold.norm(x, u))

        assert np.linalg.norm(u - v) > 1e-3

    def test_projection(self):
        # The projection of an arbitrary complex matrix must be Hermitian,
        # satisfy the tangent-space trace condition, and be idempotent.
        manifold = self.manifold
        x = manifold.random_point()
        a = np.random.randn(self.n, self.n) + 1j * np.random.randn(
            self.n, self.n
        )
        p = manifold.projection(x, a)

        assert np.shape(p) == (self.n, self.n)

        np_testing.assert_allclose(p, multiherm(p))

        t = np.real(np.trace(np.linalg.solve(x, p)))
        np_testing.assert_almost_equal(t, 0)

        np_testing.assert_allclose(p, manifold.projection(x, p))

    def test_euclidean_to_riemannian_gradient(self):
        # Points of this single-matrix manifold have shape (n, n), so the
        # Euclidean gradient does too.
        manifold = self.manifold
        x = manifold.random_point()
        u = np.random.normal(size=(self.n, self.n))
        np_testing.assert_allclose(
            manifold.euclidean_to_riemannian_gradient(x, u),
            manifold.projection(x, x @ u @ x),
        )

    def test_exp(self):
        # Check that exp(x, u) is on the manifold and that for small vectors
        # exp(x, u) ~ x + u.
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        e = manifold.exp(x, u)
        assert np.shape(e) == (self.n, self.n)

        # Check Hermitian symmetry
        np_testing.assert_allclose(e, multiherm(e))

        # Check positivity of eigenvalues
        w = np.linalg.eigvalsh(e)
        assert (w > 0).all()

        # Check unit determinant
        d = np.linalg.det(e)
        np_testing.assert_allclose(d, 1)

        u = u * 1e-6
        np_testing.assert_allclose(manifold.exp(x, u), x + u)

    def test_retraction(self):
        # Check that the result is on the manifold and that for small vectors
        # retraction(x, u) ~ x + u.
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        y = manifold.retraction(x, u)

        assert np.shape(y) == (self.n, self.n)
        # Check Hermitian symmetry
        np_testing.assert_allclose(y, multiherm(y))

        # Check positivity of eigenvalues
        w = np.linalg.eigvalsh(y)
        assert (w > 0).all()

        # Check unit determinant
        d = np.linalg.det(y)
        np_testing.assert_allclose(d, 1)

        u = u * 1e-6
        np_testing.assert_allclose(manifold.retraction(x, u), x + u)


class TestMultiSpecialHermitianPositiveDefiniteManifold(
    TestMultiHermitianPositiveDefiniteManifold
):
    """Tests for a product of ``k`` ``SpecialHermitianPositiveDefinite``.

    Each matrix in a point stack of shape (k, n, n) is Hermitian positive
    definite with unit determinant; each matrix u_i of a tangent vector at x
    satisfies real(trace(x_i^{-1} u_i)) = 0.
    """

    @pytest.fixture(autouse=True)
    def setup(self):
        self.n = n = 10
        self.k = k = 3
        self.manifold = SpecialHermitianPositiveDefinite(n, k=k)

    def test_dim(self):
        # The unit-determinant constraint removes one dimension per matrix.
        # NOTE(review): the real dimension of n x n Hermitian matrices is
        # n**2, so one would expect k * (n**2 - 1) here rather than
        # k * (n * (n + 1) - 1). This asserts the value the implementation
        # currently reports; confirm which is meant.
        manifold = self.manifold
        n = self.n
        k = self.k
        np_testing.assert_equal(manifold.dim, k * (n * (n + 1) - 1))

    def test_random_point(self):
        # Test that random_point returns a point on the manifold and that two
        # different points generated by random_point aren't too close
        # together.
        k = self.k
        n = self.n
        manifold = self.manifold
        x = manifold.random_point()
        y = manifold.random_point()

        assert np.shape(x) == (k, n, n)
        assert x.dtype == complex

        # Check Hermitian symmetry
        np_testing.assert_allclose(x, multiherm(x))

        # Check positivity of eigenvalues
        w = np.linalg.eigvalsh(x)
        assert (w > 0).all()

        # Check unit determinant
        d = np.real(np.linalg.det(x))
        np_testing.assert_allclose(d, 1)

        # Check randomness
        assert np.linalg.norm(x - y) > 1e-3

    def test_random_tangent_vector(self):
        # Test that random_tangent_vector returns an element of the tangent
        # space with unit norm and that two random tangent vectors differ.
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        v = manifold.random_tangent_vector(x)

        np_testing.assert_allclose(multiherm(u), u)

        # Per-matrix tangent condition real(trace(x_i^{-1} u_i)) = 0,
        # evaluated for the whole stack at once.
        t = np.real(
            np.trace(np.linalg.solve(x, u), axis1=-2, axis2=-1)
        )
        np_testing.assert_allclose(t, 0, atol=1e-7)

        np_testing.assert_almost_equal(1, manifold.norm(x, u))

        assert np.linalg.norm(u - v) > 1e-3

    def test_projection(self):
        # The projection of an arbitrary complex stack must be Hermitian,
        # satisfy the tangent-space trace condition, and be idempotent.
        manifold = self.manifold
        x = manifold.random_point()
        a = np.random.randn(self.k, self.n, self.n) + 1j * np.random.randn(
            self.k, self.n, self.n
        )
        p = manifold.projection(x, a)

        np_testing.assert_allclose(p, multiherm(p))

        t = np.real(
            np.trace(np.linalg.solve(x, p), axis1=-2, axis2=-1)
        )
        np_testing.assert_allclose(t, 0, atol=1e-7)

        np_testing.assert_allclose(p, manifold.projection(x, p))

    def test_euclidean_to_riemannian_gradient(self):
        manifold = self.manifold
        x = manifold.random_point()
        u = np.random.normal(size=(self.k, self.n, self.n))
        np_testing.assert_allclose(
            manifold.euclidean_to_riemannian_gradient(x, u),
            manifold.projection(x, x @ u @ x),
        )

    def test_euclidean_to_riemannian_hessian(self):
        # The inherited reference formula is for the Hermitian positive
        # definite manifold and does not account for the unit-determinant
        # constraint. Report the test as skipped rather than silently passing.
        pytest.skip(
            "Riemannian Hessian check not implemented for "
            "SpecialHermitianPositiveDefinite"
        )

    def test_exp(self):
        # Check that exp(x, u) is on the manifold and that for small vectors
        # exp(x, u) ~ x + u.
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        e = manifold.exp(x, u)
        assert np.shape(e) == (self.k, self.n, self.n)

        # Check Hermitian symmetry
        np_testing.assert_allclose(e, multiherm(e))

        # Check positivity of eigenvalues
        w = np.linalg.eigvalsh(e)
        assert (w > 0).all()

        # Check unit determinant
        d = np.linalg.det(e)
        np_testing.assert_allclose(d, 1)

        u = u * 1e-6
        np_testing.assert_allclose(manifold.exp(x, u), x + u)

    def test_retraction(self):
        # Check that the result is on the manifold and that for small vectors
        # retraction(x, u) ~ x + u.
        manifold = self.manifold
        x = manifold.random_point()
        u = manifold.random_tangent_vector(x)
        y = manifold.retraction(x, u)

        assert np.shape(y) == (self.k, self.n, self.n)
        # Check Hermitian symmetry
        np_testing.assert_allclose(y, multiherm(y))

        # Check positivity of eigenvalues
        w = np.linalg.eigvalsh(y)
        assert (w > 0).all()

        # Check unit determinant
        d = np.linalg.det(y)
        np_testing.assert_allclose(d, 1)

        u = u * 1e-6
        np_testing.assert_allclose(manifold.retraction(x, u), x + u)
